import numpy as np
import scipy.sparse as sp
import torch
from scipy.special import iv as bessel_I
from graph_utils import scipy_to_torch_sparse, degree_from_csr, symmetrize_simple, to_scipy_sparse_matrix

try:
    import faiss  # type: ignore
except Exception as e:
    raise ImportError("FAISS is required (faiss or faiss-cpu).") from e

def set_seed(seed=0):
    """Seed NumPy's global RNG and then PyTorch's RNG with the same ``seed``."""
    for seeder in (np.random.seed, torch.manual_seed):
        seeder(seed)

def _safe_k(n: int, k: int) -> int:
    return max(0, min(int(k), max(0, n - 1)))

def faiss_knn_dense_strict(X: np.ndarray, k: int):
    """Exact kNN of every row of ``X`` against all rows, excluding the row itself.

    Uses a FAISS ``IndexFlatL2`` (brute-force L2) over ``X`` cast to float32.
    ``k`` is clamped to ``[0, n - 1]`` by ``_safe_k``.

    Returns:
        ``(I, D)``: ``(n, k_eff)`` neighbour indices (dtype ``int``) and the
        matching float32 distances as reported by ``IndexFlatL2`` (squared L2),
        both in FAISS's ascending-distance order. Empty ``(n, 0)`` arrays when
        ``n == 0`` or ``k_eff == 0``.
    """
    points = np.ascontiguousarray(X.astype(np.float32))
    n_points = points.shape[0]
    k_eff = _safe_k(n_points, k)
    if n_points == 0 or k_eff == 0:
        return np.empty((n_points, 0), dtype=int), np.empty((n_points, 0), dtype=np.float32)

    index = faiss.IndexFlatL2(points.shape[1])
    index.add(points)
    # Ask for one extra hit so that dropping the query's own index still leaves k_eff.
    dists, nbrs = index.search(points, min(n_points, k_eff + 1))

    # A stable sort on the "is the query itself" flag pushes the self-hit (if present)
    # to the end of its row while keeping every other hit in FAISS's order; the first
    # k_eff columns are then exactly the non-self neighbours.
    is_self = nbrs == np.arange(n_points)[:, None]
    keep = np.argsort(is_self, axis=1, kind="stable")[:, :k_eff]
    I_out = np.take_along_axis(nbrs, keep, axis=1).astype(int)
    D_out = np.take_along_axis(dists, keep, axis=1).astype(np.float32)
    return I_out, D_out

def feature_knn_graph_faiss_safe(X: np.ndarray, k=15, sigma=None):
    """Symmetric Gaussian-weighted kNN affinity graph over the rows of ``X``.

    Each node ``i`` is linked to its ``k`` nearest other rows (``k`` clipped to
    ``n - 1`` by the kNN helper) with weight
    ``exp(-||x_i - x_j||^2 / (2 * sigma**2))``. The graph is symmetrised with an
    element-wise max, the diagonal is cleared, and explicit zeros are dropped.

    Args:
        X: feature matrix, one row per node.
        k: neighbours per node.
        sigma: kernel bandwidth in *distance* units. If None, it is set to the
            square root of the median positive squared neighbour distance, with a
            fallback of 1.0 when no positive distance exists or the value is ~0.

    Returns:
        ``(n, n)`` CSR affinity matrix.
    """
    n = X.shape[0]
    # D2 holds squared L2 distances (FAISS IndexFlatL2 convention).
    I, D2 = faiss_knn_dense_strict(X, k)
    if sigma is None:
        flat = D2.ravel()
        pos = flat[np.isfinite(flat) & (flat > 0)]
        # Take the square root so sigma is a distance: sigma**2 in the kernel is then
        # on the same scale as D2 and the weights are invariant to feature scaling.
        # (Using median(D2) directly squared the squared distances, so weights
        # collapsed to ~1 for large-scale features and underflowed for small ones.)
        sigma = float(np.sqrt(np.median(pos))) if pos.size else 1.0
        if sigma <= 1e-12:
            sigma = 1.0
    rows = np.repeat(np.arange(n), I.shape[1])
    cols = I.ravel()
    weights = np.exp(-D2.ravel() / (2.0 * (sigma ** 2)))
    W = sp.csr_matrix((weights, (rows, cols)), shape=(n, n))
    W = W.maximum(W.T)
    W.setdiag(0)
    W.eliminate_zeros()
    return W

def diffusion_knn_faiss_dense(Z: np.ndarray, k=15):
    """kNN (self excluded) over diffusion embeddings ``Z``.

    Thin alias of ``faiss_knn_dense_strict``; returns ``(indices, distances)``.
    """
    neighbours, distances = faiss_knn_dense_strict(Z, k)
    return neighbours, distances

def normalized_laplacian(W: sp.csr_matrix):
    """Symmetric normalized Laplacian ``I - D^{-1/2} W D^{-1/2}``.

    Degrees are row sums of ``W``; they are floored at 1e-12 before the inverse
    square root so isolated nodes do not divide by zero (their rows of ``W`` are
    empty, so they simply get ``L[i, i] = 1``).

    Returns:
        ``(L, d)``: the Laplacian and the *unfloored* degree vector.
    """
    adjacency = W.tocsr()
    degrees = np.asarray(adjacency.sum(axis=1)).ravel()
    inv_sqrt_deg = 1.0 / np.sqrt(np.maximum(degrees, 1e-12))
    scaler = sp.diags(inv_sqrt_deg)
    identity = sp.eye(adjacency.shape[0], format='csr')
    laplacian = identity - scaler @ adjacency @ scaler
    return laplacian, degrees

def mix_laplacian(L_top: sp.csr_matrix, L_feat: sp.csr_matrix, alpha=(0.5, 0.5)):
    """Convex combination of two Laplacians.

    ``alpha`` is normalised to sum to 1 (with a 1e-12 guard); its first entry
    weights ``L_top`` and its second weights ``L_feat``.
    """
    weights = np.asarray(alpha, dtype=float)
    weights = weights / (weights.sum() + 1e-12)
    w_top, w_feat = weights[0], weights[1]
    return w_top * L_top + w_feat * L_feat

def heat_kernel_apply(L: sp.csr_matrix, t=0.6, Omega=None, order=25, n_probes: int = 32):
    """Apply ``exp(-t L)`` to probe vectors via a Chebyshev expansion and row-normalise.

    With ``A = L - I`` the expansion used is

        exp(-t L) = e^{-t} [ I_0(t) T_0(A) + 2 * sum_{k=1..order} (-1)^k I_k(t) T_k(A) ]

    where ``I_k`` are modified Bessel functions and ``T_k`` Chebyshev polynomials.
    It is accurate when the spectrum of ``L`` lies in ``[0, 2]`` (so that of ``A``
    lies in ``[-1, 1]``), e.g. for a normalized Laplacian or a convex mix of them.
    Larger ``t`` needs a larger ``order`` for the same accuracy.

    Args:
        L: sparse ``(n, n)`` operator.
        t: diffusion time.
        Omega: ``(n, m)`` probe matrix; if None, ``n_probes`` standard-normal
            columns are drawn from NumPy's global RNG.
        order: highest Chebyshev degree kept.
        n_probes: number of random probes used when ``Omega`` is None.

    Returns:
        ``(n, m)`` float array whose rows have unit Euclidean norm (a 1e-12
        guard avoids division by zero for all-zero rows).
    """
    # ive(k, t) = iv(k, t) * exp(-|t|) folds the e^{-t} prefactor into each
    # coefficient, so large t no longer overflows iv() and yields inf * 0 = nan.
    from scipy.special import ive

    n = L.shape[0]
    if Omega is None:
        Omega = np.random.randn(n, n_probes)
    A = L - sp.eye(n, format='csr')

    T_prev = np.array(Omega, dtype=float)  # T_0(A) @ Omega
    Y = ive(0, t) * T_prev
    if order >= 1:
        T_cur = A @ T_prev  # T_1(A) @ Omega
        for k in range(1, order + 1):
            Y += 2.0 * ((-1) ** k) * ive(k, t) * T_cur
            # Three-term recurrence T_{k+1} = 2 A T_k - T_{k-1}; skip it after the
            # last kept term to avoid a wasted sparse product.
            if k < order:
                T_prev, T_cur = T_cur, 2.0 * (A @ T_cur) - T_prev

    nrm = np.linalg.norm(Y, axis=1, keepdims=True) + 1e-12
    return Y / nrm

def incompatibility_from_dense_knn(I_knn: np.ndarray, degree: np.ndarray):
    """Sparse "incompatibility" matrix linking nodes that are NOT kNN neighbours.

    For every node ``i``, each node ``j`` outside ``i``'s ball (its listed
    neighbours ``I_knn[i]`` plus ``i`` itself) gets
    ``M[i, j] = 1 / (d_i * d_j)``, with degrees floored at 1e-12. The pattern is
    symmetrised with an element-wise max (the values are already symmetric), and
    the diagonal is zero. Neighbour indices outside ``[0, n)`` (e.g. a FAISS
    ``-1`` placeholder) are ignored.

    The result has O(n^2) nonzeros by construction. This vectorised version builds
    one ``(n, n)`` boolean mask instead of Python lists of n^2 entries, producing
    the same matrix much faster and with far less memory.

    Args:
        I_knn: ``(n, k)`` integer neighbour indices.
        degree: length-``n`` degree vector.

    Returns:
        ``(n, n)`` CSR matrix.
    """
    I_knn = np.asarray(I_knn)
    n = I_knn.shape[0]
    inv_d = 1.0 / np.maximum(degree, 1e-12)

    # excluded[i, j] is True when j lies in i's ball (neighbour or i itself).
    excluded = np.zeros((n, n), dtype=bool)
    diag = np.arange(n)
    excluded[diag, diag] = True
    if I_knn.size:
        src = np.repeat(diag, I_knn.shape[1])
        dst = I_knn.reshape(-1)
        valid = (dst >= 0) & (dst < n)
        excluded[src[valid], dst[valid]] = True

    rows, cols = np.nonzero(~excluded)
    vals = inv_d[rows] * inv_d[cols]
    M = sp.csr_matrix((vals, (rows, cols)), shape=(n, n))
    M = M.maximum(M.T)
    M.setdiag(0)
    M.eliminate_zeros()
    return M

@torch.no_grad()
def estimate_lmax_power_sparse(L_sp: torch.Tensor, iters: int = 20):
    """Estimate the dominant eigenvalue of a sparse square matrix by power iteration.

    Starts from a random unit vector (``torch.randn``) and returns the Rayleigh
    quotient of the last iterate, as a 0-d tensor on ``L_sp``'s device/dtype.
    Returns 0 if an iterate is mapped to the zero vector or when ``iters == 0``.
    """
    dev, dt = L_sp.device, L_sp.dtype
    v = torch.randn(L_sp.size(0), 1, device=dev, dtype=dt)
    v = v / (v.norm() + 1e-12)
    rayleigh = 0.0
    for _ in range(iters):
        Lv = torch.sparse.mm(L_sp, v)
        Lv_norm = Lv.norm()
        if Lv_norm.item() == 0.0:
            return torch.tensor(0.0, device=dev, dtype=dt)
        v = Lv / Lv_norm
        rayleigh = (v.t() @ torch.sparse.mm(L_sp, v)).item()
    return torch.tensor(rayleigh, device=dev, dtype=dt)

@torch.no_grad()
def block_power_smallest(L_sp: torch.Tensor, K: int, iters: int = 12, deg_C: torch.Tensor | None = None):
    """Subspace (block power) iteration aimed at the low end of ``L_sp``'s spectrum.

    Iterates ``B = I - L / lmax``, where ``lmax`` comes from
    ``estimate_lmax_power_sparse`` (floored at 1e-12). After every application of
    ``B`` the unit direction ``sqrt(deg_C)`` is projected out (all-ones when
    ``deg_C`` is None), and the block is re-orthonormalised with a reduced QR.

    Returns:
        ``(n, K)`` tensor with orthonormal columns.
    """
    dev, dt = L_sp.device, L_sp.dtype
    n_nodes = L_sp.size(0)
    lmax = estimate_lmax_power_sparse(L_sp, iters=12).clamp_min(1e-12)

    weights = deg_C if deg_C is not None else torch.ones(n_nodes, device=dev, dtype=dt)
    direction = weights.clamp_min(1e-12).sqrt().unsqueeze(1)
    direction = direction / (direction.norm() + 1e-12)

    basis, _ = torch.linalg.qr(torch.randn(n_nodes, K, device=dev, dtype=dt), mode='reduced')
    for _ in range(iters):
        shifted = basis - (1.0 / lmax) * torch.sparse.mm(L_sp, basis)
        shifted = shifted - direction * (direction.t() @ shifted)
        basis, _ = torch.linalg.qr(shifted, mode='reduced')
    return basis

class LinearHeadK(torch.nn.Module):
    """Two-layer MLP (Linear -> ReLU -> Linear) mapping K features to K logits.

    Both linear layers get Xavier-uniform weights with gain 0.1 and zero biases,
    so the initial outputs are small.
    """

    def __init__(self, K, hidden=32, dtype=torch.float64):
        super().__init__()
        layers = [
            torch.nn.Linear(K, hidden, dtype=dtype),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden, K, dtype=dtype),
        ]
        self.net = torch.nn.Sequential(*layers)
        for layer in self.net:
            if isinstance(layer, torch.nn.Linear):
                torch.nn.init.xavier_uniform_(layer.weight, gain=0.1)
                torch.nn.init.zeros_(layer.bias)

    def forward(self, U):
        return self.net(U)

def sinkhorn_balanced(logits: torch.Tensor, n_iters=7, tau=1.0):
    """Turn ``(n, K)`` logits into a row-stochastic matrix with ~balanced columns.

    ``P = exp(logits / tau) + 1e-9`` goes through ``n_iters`` rounds of row
    normalisation followed by rescaling the columns towards mass ``n / K``. A final
    row normalisation makes every row sum to 1, so column sums are only
    approximately ``n / K``.

    Each row is shifted by its (finite) maximum before ``exp``. The first row
    normalisation cancels that per-row constant, so the result matches the unshifted
    formula except that the 1e-9 floor becomes relative to the row's largest entry.
    This prevents ``exp`` from overflowing to ``inf``. The overflow previously turned
    the output into NaN once logits/tau exceeded ~88 (float32) or ~709 (float64).
    """
    n, K = logits.shape
    z = logits / tau
    if z.numel() > 0:
        shift = z.amax(dim=1, keepdim=True)
        # Leave rows with a non-finite maximum unshifted (avoids -inf - -inf = nan).
        shift = torch.where(torch.isfinite(shift), shift, torch.zeros_like(shift))
        z = z - shift
    P = torch.exp(z) + 1e-9
    col_tgt = (n / K) * torch.ones(K, device=logits.device, dtype=logits.dtype)
    for _ in range(n_iters):
        P = P / (P.sum(dim=1, keepdim=True) + 1e-12)
        col_sum = P.sum(dim=0) + 1e-12
        P = P * (col_tgt / col_sum)
    P = P / (P.sum(dim=1, keepdim=True) + 1e-12)
    return P

@torch.no_grad()
def assignments_with_margin(U: torch.Tensor, hidden=32, zeta=1e-2, method="sinkhorn", tau=1.0, sinkhorn_iters=12):
    """Soft cluster assignments from spectral coordinates ``U`` (``(n, K)``).

    A freshly constructed ``LinearHeadK`` (randomly initialised, not trained here)
    maps each row to K logits, which are centred per row. Each row is then scaled by
    ``mu = max(top1 - top2, 0) / (|top1| + |top2| + zeta)``, computed from its two
    largest centred logits. The scaled logits are converted with
    ``sinkhorn_balanced`` when ``method`` is "sinkhorn" (case-insensitive), and with
    a row softmax otherwise. Requires K >= 2.

    Returns:
        ``(S, logits, mu)``: ``(n, K)`` assignments, the centred ``(n, K)`` logits,
        and the per-row ``(n,)`` scale factors.
    """
    n_clusters = U.size(1)
    head = LinearHeadK(K=n_clusters, hidden=hidden, dtype=U.dtype).to(U.device)
    raw = head(U)
    centred = raw - raw.mean(dim=1, keepdim=True)

    best, runner_up = torch.topk(centred, k=2, dim=1).values.unbind(dim=1)
    margin = (best - runner_up).clamp_min(0.0)
    scale = margin / (best.abs() + runner_up.abs() + zeta)

    scaled_logits = scale.unsqueeze(1) * centred
    if method.lower() != "sinkhorn":
        return torch.softmax(scaled_logits, dim=1), centred, scale
    S = sinkhorn_balanced(scaled_logits, n_iters=sinkhorn_iters, tau=tau)
    return S, centred, scale

def decide_K(N_cur, ratio, last_level):
    """Number of clusters for the next level.

    Returns 1 (collapse to a single root) when at most two nodes remain or on the
    last level; otherwise ``int(N_cur * ratio) + 1`` clamped into ``[1, N_cur]``.
    """
    if N_cur <= 2 or last_level:
        return 1
    proposed = int(N_cur * ratio) + 1
    capped = min(proposed, N_cur)
    return max(1, capped)

def coarsen_adj_hard(A: sp.csr_matrix, hard_labels: np.ndarray, K: int):
    """Aggregate ``A`` onto ``K`` super-nodes using one hard label per node.

    Every stored nonzero ``A[i, j]`` (as listed by ``sp.find``) is added to entry
    ``(hard_labels[i], hard_labels[j])`` of a ``(K, K)`` CSR matrix. Duplicates are
    summed by the constructor, and intra-cluster edges land on the diagonal. The
    result is passed through ``symmetrize_simple`` from ``graph_utils``, whose exact
    post-processing is not visible here.
    """
    src, dst, weight = sp.find(A)
    coarse = sp.csr_matrix((weight, (hard_labels[src], hard_labels[dst])), shape=(K, K))
    return symmetrize_simple(coarse)

def _unique_seeds_for_all_clusters(S: np.ndarray) -> np.ndarray:
    N, K = S.shape
    order = np.argsort(-S, axis=0)
    seeds = -np.ones(K, dtype=int); used = np.zeros(N, dtype=bool); ptr = np.zeros(K, dtype=int)
    remaining = list(range(K)); guard = 0
    while remaining and guard < K * N:
        k = remaining.pop(0)
        while ptr[k] < N and used[order[ptr[k], k]]: ptr[k] += 1
        if ptr[k] < N:
            i = order[ptr[k], k]; seeds[k] = i; used[i] = True
        else:
            i = int(np.argmin(used)); seeds[k] = i; used[i] = True
        guard += 1
    return seeds

def hard_labels_cover_all(S: np.ndarray) -> np.ndarray:
    """Arg-max hard labels from soft assignments ``S`` (``(N, K)``), repaired so no
    cluster is empty (when ``K <= N``).

    If some column never wins the arg-max, one distinct seed node per cluster is
    chosen with ``_unique_seeds_for_all_clusters`` and node ``seeds[k]`` is
    relabelled to ``k``. All other nodes keep their arg-max label.
    """
    _, n_clusters = S.shape
    labels = S.argmax(axis=1)
    cluster_sizes = np.bincount(labels, minlength=n_clusters)
    if not np.all(cluster_sizes):
        seeds = _unique_seeds_for_all_clusters(S)
        labels[seeds] = np.arange(n_clusters, dtype=int)
    return labels

def Make_tree_HMH(
    X, A, levels: int, ratio: float = 0.2, lam: float = 0.1,
    k_feat: int = 15, k_diff: int = 15, t_heat: float = 0.6, cheb_order: int = 25,
    alpha=(0.5, 0.5), device: str = "cpu", dtype = torch.float64,
    assign_method: str = "sinkhorn", tau: float = 1.0, sinkhorn_iters: int = 7, seed: int = 0,
):
    """Build a coarsening hierarchy (tree) of a graph with node features.

    Starting from ``(A, X)``, each of up to ``levels - 1`` steps clusters the
    current nodes into K groups (``decide_K``) and produces a coarser graph:

    1. mix the normalized Laplacian of ``A`` with that of a Gaussian kNN graph
       over the features (``k_feat`` neighbours, weights ``alpha``);
    2. embed the nodes by applying a Chebyshev heat-kernel approximation
       (``t_heat``, ``cheb_order``) of the mixed Laplacian to random probes;
    3. build an incompatibility matrix linking each node to every node outside
       its ``k_diff`` nearest neighbours in that embedding;
    4. run block power iteration on ``L_mix + lam * M_C`` to get K vectors,
       and turn them into soft assignments ``S`` (``assign_method``: Sinkhorn or
       softmax);
    5. derive hard labels that leave no cluster empty, coarsen the adjacency
       with them, and pool features as ``S.T @ X`` (weighted sums).

    When K is 1 (at most two nodes left, or the last level) every remaining
    node is merged into a single root and the loop stops early.

    The procedure is stochastic: heat-kernel probes, power-iteration starts and
    the assignment head's initialisation are all random. ``seed`` is applied to
    NumPy and PyTorch via ``set_seed`` first.

    Args:
        X: node feature matrix, one row per node; a torch tensor is converted
            to NumPy.
        A: input adjacency; anything with ``.tocsr()``. Presumably a scipy
            sparse matrix (TODO confirm against callers).
        levels: maximum number of tree levels, including the input graph.
        ratio: each step uses ``K = int(N_cur * ratio) + 1`` clusters,
            clamped to ``[1, N_cur]``.
        lam: weight of the incompatibility matrix in the augmented operator.
        k_feat, k_diff: neighbour counts for the feature graph and for the
            diffusion-space kNN.
        t_heat, cheb_order: heat-kernel time and Chebyshev order.
        alpha: (topology, feature) mixing weights for the Laplacians.
        device, dtype: where and in which dtype the torch stages run.
        assign_method, tau, sinkhorn_iters: soft-assignment options.
        seed: RNG seed.

    Returns:
        ``(treeG, S_assign_list)``. ``treeG[l]`` is a dict with:
          * ``'adj'``: adjacency of level ``l``;
          * ``'features'``: features of level ``l``;
          * ``'IDX'``: ``arange(N)`` at level 0; for ``l >= 1``, the cluster
            index (at level ``l``) of each level-``l-1`` node;
          * ``'clusters'``: level 0 has one singleton per node; for ``l >= 1``
            the k-th entry lists the level-``l-1`` nodes in cluster k.
        ``S_assign_list[l]`` is the soft assignment from level ``l`` to level
        ``l + 1`` (a column of ones when K was 1).
    """
    # Seeds NumPy and torch; every random draw below depends on it.
    set_seed(seed)
    if isinstance(X, torch.Tensor): X = X.detach().cpu().numpy()
    A = A.tocsr()
    N_start = A.shape[0]

    # adj_list / features_list include level 0 (the input graph);
    # parents / S_assign_list get one entry per coarsening step.
    adj_list = [A]; features_list = [X]; parents = []; S_assign_list = []
    for level in range(levels - 1):
        N_cur = A.shape[0]; last_level = (level == levels - 2)
        K = decide_K(N_cur, ratio, last_level)
        if K == 1:
            # Merge every remaining node into one root and stop early.
            S_triv = np.ones((N_cur, 1), dtype=np.float64)
            S_assign_list.append(S_triv); parents.append(np.zeros(N_cur, dtype=int))
            X = S_triv.T @ X
            A = coarsen_adj_hard(A, np.zeros(N_cur, dtype=int), K=1)
            adj_list.append(A); features_list.append(X); break

        # Mixed operator: normalized topology Laplacian + feature-kNN Laplacian.
        L_top, deg_top = normalized_laplacian(A)
        W_feat = feature_knn_graph_faiss_safe(X, k=k_feat, sigma=None)
        L_feat, _ = normalized_laplacian(W_feat)
        L_mix = mix_laplacian(L_top, L_feat, alpha=alpha)
        # Heat-kernel embedding (random probes, Omega=None) and kNN in that space.
        Z = heat_kernel_apply(L_mix, t=t_heat, order=cheb_order)
        I_diff, _ = diffusion_knn_faiss_dense(Z, k=k_diff)
        # Links each node to all nodes outside its diffusion neighbourhood,
        # weighted by inverse topology degrees.
        M_C = incompatibility_from_dense_knn(I_diff, degree=deg_top)

        # Augmented operator L_mix + lam * M_C as a torch sparse tensor.
        Lmix_sp_t = scipy_to_torch_sparse(L_mix, device=device, dtype=dtype)
        MC_sp_t   = scipy_to_torch_sparse(M_C,   device=device, dtype=dtype)
        L_aug_t   = (Lmix_sp_t + lam * MC_sp_t).coalesce()

        # K orthonormal vectors from block power iteration; the sqrt(deg_top)
        # direction is projected out at each step.
        deg_t = torch.from_numpy(deg_top).to(device=device, dtype=dtype)
        U_t = block_power_smallest(L_aug_t, K=K, iters=12, deg_C=deg_t)
        # Soft assignments; logits_t and mu_t are not used further in this function.
        S_t, logits_t, mu_t = assignments_with_margin(
            U_t, hidden=32, zeta=1e-3, method=assign_method, tau=tau, sinkhorn_iters=sinkhorn_iters
        )
        S_np = S_t.detach().cpu().numpy()

        # Hard labels (no empty cluster) drive the adjacency coarsening.
        hard_labels = hard_labels_cover_all(S_np)
        A_next = coarsen_adj_hard(A, hard_labels, K)
        adj_list.append(A_next)

        # Features are pooled with the soft assignments (weighted sums, not means).
        X_next = S_np.T @ X
        features_list.append(X_next)

        parents.append(hard_labels)
        S_assign_list.append(S_np)

        A, X = A_next, X_next

    # Assemble one record per level; level l >= 1 is expressed in terms of the
    # nodes of level l - 1 (not the original nodes).
    L_eff = len(adj_list)
    treeG = [None] * L_eff
    for lvl in range(L_eff):
        if lvl == 0:
            idxs = np.arange(N_start)
            clusters = [np.array([i], dtype=int) for i in idxs]
            IDX_vec = np.arange(N_start)
        else:
            pid = parents[lvl - 1]
            K_lvl = S_assign_list[lvl - 1].shape[1]
            clusters = [np.flatnonzero(pid == k) for k in range(K_lvl)]
            IDX_vec = pid
        treeG[lvl] = {'IDX': IDX_vec, 'clusters': clusters, 'adj': adj_list[lvl], 'features': features_list[lvl]}
    return treeG, S_assign_list

def Uext_batch_from_tree_lists_HMH(
    X_list, edge_index_list, levels=5, ratio=0.3,
    lam=0.1, k_feat=15, k_diff=15, t_heat=0.6, cheb_order=25,
    alpha=(0.5,0.5), device="cpu", dtype=torch.float64,
    assign_method="sinkhorn", tau=0.9, sinkhorn_iters=10, seed=42
):
    """Build an HMH tree for every graph in a batch and collect per-level data.

    Each ``(X_i, edge_index_i)`` pair is converted to a sparse adjacency with
    ``to_scipy_sparse_matrix`` (``num_nodes = X_i.shape[0]``). It is then passed to
    ``Make_tree_HMH`` with the shared hyper-parameters, including the same ``seed``
    for every graph.

    Returns a 7-tuple of per-graph lists:
        U placeholders (``[None] * (depth - 1)`` per graph),
        per-level edge indices (from ``graph_utils.adj2edge``),
        per-level node counts (number of clusters),
        per-level edge counts (``eidx.size(1)``),
        per-level features,
        the tree itself,
        the soft-assignment list returned by ``Make_tree_HMH``.
    """
    from graph_utils import adj2edge

    per_graph_U = []
    per_graph_edges = []
    per_graph_node_counts = []
    per_graph_edge_counts = []
    per_graph_features = []
    per_graph_trees = []
    per_graph_assign = []

    for X_i, ei_i in zip(X_list, edge_index_list):
        A_i = to_scipy_sparse_matrix(ei_i, num_nodes=X_i.shape[0])
        tree, assigns = Make_tree_HMH(
            X=X_i, A=A_i, levels=levels, ratio=ratio, lam=lam,
            k_feat=k_feat, k_diff=k_diff, t_heat=t_heat, cheb_order=cheb_order,
            alpha=alpha, device=device, dtype=dtype,
            assign_method=assign_method, tau=tau, sinkhorn_iters=sinkhorn_iters, seed=seed
        )
        depth = len(tree)

        level_edges = []
        for node in tree:
            eidx, _ = adj2edge(node['adj'])
            level_edges.append(eidx)
        node_counts = np.array([len(node['clusters']) for node in tree], dtype=int)
        edge_counts = np.array([eidx.size(1) for eidx in level_edges], dtype=int)

        per_graph_U.append([None] * (depth - 1))
        per_graph_edges.append(level_edges)
        per_graph_node_counts.append(node_counts)
        per_graph_edge_counts.append(edge_counts)
        per_graph_features.append([node['features'] for node in tree])
        per_graph_trees.append(tree)
        per_graph_assign.append(assigns)

    return (
        per_graph_U,
        per_graph_edges,
        per_graph_node_counts,
        per_graph_edge_counts,
        per_graph_features,
        per_graph_trees,
        per_graph_assign,
    )

